// Copyright 2014 Google Inc. All Rights Reserved.

#include "SslWrapper.h"

#if defined(__MINGW32__)
#include <winsock2.h>
#endif

#include <openssl/bio.h>
#include <openssl/conf.h>
#include <openssl/engine.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#include <openssl/ssl.h>

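// Unavoidable, unfortunately: OpenSSL's locking callback (sslLockingFunction)
// needs a global array of CRYPTO_num_locks() mutexes, so it is allocated here
// during static initialization. This is a dirty, dirty hack.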
// Allocates one Mutex for every lock slot OpenSSL reports via CRYPTO_num_locks().
static Mutex* sslGlobalMutexInit() {
    const int lockCount = CRYPTO_num_locks();
    Mutex* const mutexes = new Mutex[lockCount];
    return mutexes;
}
// Created during static initialization; sslLockingFunction indexes into it.
static Mutex* sslMutexes = sslGlobalMutexInit();

static void sslLockingFunction(int mode, int n, const char* /* file */, int /* line */) {
    assert(sslMutexes != NULL);
    if (mode & CRYPTO_LOCK) {
        sslMutexes[n].lock();
    } else {
        sslMutexes[n].unlock();
    }
}

// Every OpenSSL handle owned by one SslWrapper, plus the lock that serializes
// SSL_read/SSL_write on mSsl. All handles start out NULL.
struct SslWrapperState {
    X509* mRootCert;         // Root certificate added to mX509Store.
    X509* mClientCert;       // Client certificate loaded into mSslCtx.
    BIO* mRbio;              // Calls to SSL_read use this as input.
    BIO* mWbio;              // Calls to SSL_write use this as output.
    EVP_PKEY* mPrivateKey;   // Client private key loaded into mSslCtx.
    SSL_CTX* mSslCtx;
    SSL* mSsl;
    X509_STORE* mX509Store;  // Used by verifyPeer() to check the peer chain.
    Mutex mSslLock;          // Held around SSL_write / SSL_read on mSsl.

    SslWrapperState()
      : mRootCert(NULL), mClientCert(NULL),
        mRbio(NULL), mWbio(NULL),
        mPrivateKey(NULL),
        mSslCtx(NULL), mSsl(NULL),
        mX509Store(NULL) {}
};

// Creates an empty wrapper; the OpenSSL objects themselves are built by init().
SslWrapper::SslWrapper()
    : mAuthState(AUTHENTICATION_UNINITIALIZED)
    , mState(new SslWrapperState())
    , mHasOverrideTime(false)
    , mOverrideTime(0) {
}

// Releases the state holder.
// NOTE(review): the OpenSSL objects it points to are freed only by shutdown(),
// which this destructor does not call.
SslWrapper::~SslWrapper() {
    SslWrapperState* const state = mState;
    delete state;
}

bool SslWrapper::init(const string& rootCert, const string& clientCert, const string& clientKey) {
    SSL_load_error_strings();
    SSL_library_init();
    OpenSSL_add_all_algorithms();

    if (RAND_status() != 1) {
        LOGE("No /dev/urandom - Must seed RNG manually.");
        return false;
    }

    // OpenSSL needs this to be thread safe.
    CRYPTO_set_locking_callback(sslLockingFunction);

    BIO* bio = BIO_new_mem_buf((void*)rootCert.c_str(), rootCert.size());
    mState->mRootCert = PEM_read_bio_X509(bio, NULL, NULL, NULL);
    BIO_free(bio);

    bio = BIO_new_mem_buf((void*)clientCert.c_str(), clientCert.size());
    mState->mClientCert = PEM_read_bio_X509(bio, NULL, NULL, NULL);
    BIO_free(bio);

    bio = BIO_new_mem_buf((void*)clientKey.c_str(), clientKey.size());
    mState->mPrivateKey = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
    BIO_free(bio);

    mState->mSslCtx = SSL_CTX_new(TLSv1_2_client_method());
    if (mState->mSslCtx == NULL) {
        LOGE("Failed to allocate SSL_CTX.");
        return false;
    }
    if (SSL_CTX_use_certificate(mState->mSslCtx, mState->mClientCert) != 1) {
        LOGE("Failed to set client certificate.");
        return false;
    }
    if (SSL_CTX_use_PrivateKey(mState->mSslCtx, mState->mPrivateKey) != 1) {
        LOGE("Failed to set private key.");
        return false;
    }
    long options = SSL_OP_NO_TLSv1 | SSL_OP_NO_SSLv3 | SSL_OP_NO_SSLv2
            | SSL_OP_NO_COMPRESSION /* CRIME attack */
            | SSL_OP_SINGLE_DH_USE | SSL_OP_SINGLE_ECDH_USE; /* small subgroup attack */
    SSL_CTX_set_options(mState->mSslCtx, options);

    mState->mSsl = SSL_new(mState->mSslCtx);
    if (mState->mSsl == NULL) {
        LOGE("Failed to allocate SSL.");
        return false;
    }
    if (SSL_check_private_key(mState->mSsl) != 1) {
        LOGE("SSL check private key failed!");
        return false;
    }
    mState->mRbio = BIO_new(BIO_s_mem());
    if (mState->mRbio == NULL) {
        LOGE("Failed to allocate read bio.");
        return false;
    }
    mState->mWbio = BIO_new(BIO_s_mem());
    if (mState->mWbio == NULL) {
        LOGE("Failed to allocate write bio.");
        return false;
    }
    SSL_set_bio(mState->mSsl, mState->mRbio, mState->mWbio);
    SSL_set_connect_state(mState->mSsl);

    // Set up a X509 store to verify peer certificates.
    mState->mX509Store = X509_STORE_new();
    if (mState->mX509Store == NULL) {
        LOGE("Failed to allocate x509 store.");
        return false;
    }
    if (X509_STORE_add_cert(mState->mX509Store, mState->mRootCert) != 1) {
        LOGE("Failed to add root certificate.");
        return false;
    }
    X509_STORE_set_flags(mState->mX509Store, 0);
    mAuthState = AUTHENTICATION_NOT_STARTED;
    return true;
}

// Frees every OpenSSL object held in mState. Unless building against BoringSSL,
// it then runs OpenSSL's global cleanup routines, and finally releases this
// thread's error state. Each pointer is reset to NULL after it is freed, so a
// second call does not double-free.
// NOTE(review): the global cleanup calls are process-wide; presumably no other
// SslWrapper may still be in use when this runs — confirm.
void SslWrapper::shutdown() {
    if (mState->mSsl != NULL) {
        // SSL_free also releases the BIOs attached with SSL_set_bio().
        SSL_free(mState->mSsl);
    } else {
        if (mState->mRbio != NULL) {
            BIO_free(mState->mRbio);
        }
        if (mState->mWbio != NULL) {
            BIO_free(mState->mWbio);
        }
    }
    mState->mSsl = NULL;
    mState->mRbio = NULL;
    mState->mWbio = NULL;

    if (mState->mSslCtx != NULL) {
        SSL_CTX_free(mState->mSslCtx);
        mState->mSslCtx = NULL;
    }
    if (mState->mX509Store != NULL) {
        X509_STORE_free(mState->mX509Store);
        mState->mX509Store = NULL;
    }
    if (mState->mRootCert != NULL) {
        X509_free(mState->mRootCert);
        mState->mRootCert = NULL;
    }
    if (mState->mClientCert != NULL) {
        X509_free(mState->mClientCert);
        mState->mClientCert = NULL;
    }
    if (mState->mPrivateKey != NULL) {
        EVP_PKEY_free(mState->mPrivateKey);
        mState->mPrivateKey = NULL;
    }

#if !defined(OPENSSL_IS_BORINGSSL)
    CONF_modules_free();
    ENGINE_cleanup();
    CONF_modules_unload(1);
    EVP_cleanup();
    CRYPTO_cleanup_all_ex_data();
#endif

    ERR_remove_thread_state(NULL);
}

// logErrorQueueCallback is a callback function that is called by
// |ERR_print_errors_cb| for each error in the current thread's error queue.
// |str| is one formatted error line of |len| bytes. Trailing newline
// characters are trimmed, since the other LOGE calls in this file pass
// messages without one. Returning 1 lets ERR_print_errors_cb move on to the
// next entry.
static int logErrorQueueCallback(const char *str, size_t len, void * /* ctx */) {
    while (len > 0 && (str[len - 1] == '\n' || str[len - 1] == '\r')) {
        --len;
    }
    LOGE("  %.*s", static_cast<int>(len), str);
    return 1;
}

// logErrorQueue logs each entry in the OpenSSL error queue for the current
// thread using LOGE. It then clears the queue.
static void logErrorQueue() {
    ERR_print_errors_cb(logErrorQueueCallback, NULL);
    ERR_clear_error();
}

int SslWrapper::handshake(const void* in, size_t len, IoBuffer* out) {
    if (in != NULL) {
        BIO_write(mState->mRbio, in, len);
    }

    int ret = SSL_do_handshake(mState->mSsl);
    int err = SSL_get_error(mState->mSsl, ret);
    LOG("ssl state=%s %d", SSL_state_string_long(mState->mSsl), ret);
    ERR_remove_thread_state(NULL);

    if (ret == 1) {
        mAuthState = AUTHENTICATION_SUCCESS;
        LOG("SSL version=%s Cipher name=%s", SSL_get_version(mState->mSsl), SSL_get_cipher(mState->mSsl));
    } else if (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) {
        int len = BIO_pending(mState->mWbio);
        if (len > 0) {
            out->resize(len + sizeof(uint16_t));
            WRITE_BE16(out->raw(), MESSAGE_ENCAPSULATED_SSL);
            BIO_read(mState->mWbio, (uint8_t*)out->raw() + sizeof(uint16_t), len);
        }
    } else {
        mAuthState = AUTHENTICATION_FAILURE;
        logErrorQueue();
    }
    return ret;
}

bool SslWrapper::verifyPeer() {
    X509_STORE_CTX *ctx = X509_STORE_CTX_new();
    X509* cert = SSL_get_peer_certificate(mState->mSsl);
    STACK_OF(X509)* chain = SSL_get_peer_cert_chain(mState->mSsl);
    X509_STORE_CTX_init(ctx, mState->mX509Store, cert, NULL);
    if (chain != NULL) {
        X509_STORE_CTX_set_chain(ctx, chain);
    }

    // Allow OEM's who don't have wall time on their head unit to manually set the time.
    // usually obtained from some other part of the system.
    if (mHasOverrideTime) {
        X509_STORE_CTX_set_time(ctx, 0, mOverrideTime);  // flags is ignored.
    }

    int ret = X509_verify_cert(ctx);
    int err = X509_STORE_CTX_get_error(ctx);
    LOG("Verify returned: %s", X509_verify_cert_error_string(err));
    X509_STORE_CTX_cleanup(ctx);
    X509_STORE_CTX_free(ctx);
    X509_free(cert);
#if NO_SSL_CERT_DATE_CHECK
    return (err == X509_V_OK) || (err == X509_V_ERR_CERT_NOT_YET_VALID);
#else
    return ret == 1;
#endif
}

int SslWrapper::encryptionPipelineEnqueue(void* data, size_t len) {
    mState->mSslLock.lock();
    int ret = SSL_write(mState->mSsl, data, len);
    mState->mSslLock.unlock();

    if (ret <= 0) {
        // It's assumed that, during a normal connection, SSL_ERROR_WANT_READ
        // and SSL_ERROR_WANT_WRITE don't occur because the handshake has
        // completed.
        LOGE("Failed to write data to mState->mSsl!");
        return -1;
    }
    return BIO_pending(mState->mWbio);
}

// Copies up to |len| encrypted bytes out of mWbio into |data|.
// Returns the BIO_read() result; a negative value is also logged.
// NOTE(review): an empty memory BIO may report a negative (retry) result here
// too, and this read is not done under mSslLock — confirm callers account for both.
int SslWrapper::encryptionPipelineDequeue(void* data, size_t len) {
    const int bytesRead = BIO_read(mState->mWbio, data, len);
    if (bytesRead >= 0) {
        return bytesRead;
    }
    LOGE("Failed to read from encryption pipeline.");
    return bytesRead;
}

// Appends |len| bytes of ciphertext from |data| to mRbio, the input consumed by
// SSL_read. Returns true only if every byte was written.
bool SslWrapper::decryptionPipelineEnqueue(void* data, size_t len) {
    const int written = BIO_write(mState->mRbio, data, len);
    const bool wroteAll = (written == static_cast<int>(len));
    if (!wroteAll) {
        LOGE("Failed to write data to the read bio!");
    }
    return wroteAll;
}

int SslWrapper::decryptionPipelineDequeue(void* data, size_t len) {
    mState->mSslLock.lock();
    int ret = SSL_read(mState->mSsl, data, len);
    if (ret >= 0) {
        mState->mSslLock.unlock();
        return ret;
    }

    int sslErr = SSL_get_error(mState->mSsl, ret);
    mState->mSslLock.unlock();
    if (sslErr == SSL_ERROR_WANT_READ) {
        return 0;
    }

    LOGE("Error from SSL_read (ret=%d, sslErr=%d):", ret, sslErr);
    logErrorQueue();
    return -1;
}
